{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Prepare Data Set\n",
    "\n",
    "First, a data set is loaded. The function `load_data_from_df` automatically saves the calculated features to the provided data directory (unless `use_data_saving` is set to `False`). Every subsequent run will reuse the saved features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "# The following cells import project modules and read '../data' relative to the 'src' directory.\n",
    "# The guard keeps this cell re-runnable: once the kernel is already inside 'src',\n",
    "# a second os.chdir('src') would try to enter 'src/src'.\n",
    "if os.path.basename(os.getcwd()) != 'src':\n",
    "    os.chdir('src')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RDKit WARNING: [19:38:24] Enabling RDKit 2019.09.3 jupyter extensions\n"
     ]
    }
   ],
   "source": [
    "from featurization.data_utils import load_data_from_df, construct_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Formal charges are one-hot encoded to stay compatible with the pre-trained weights.\n",
    "# If you do not plan to use the pre-trained weights, we recommend setting one_hot_formal_charge to False.\n",
    "dataset_path = '../data/freesolv/freesolv.csv'\n",
    "X, y = load_data_from_df(dataset_path, one_hot_formal_charge=True)\n",
    "\n",
    "# Wrap the featurized molecules and their labels in a batched loader.\n",
    "batch_size = 64\n",
    "data_loader = construct_loader(X, y, batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can use your own data, but the CSV file should contain two columns, `smiles` and `y`, as shown below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>CN(C)C(=O)c1ccc(cc1)OC</td>\n",
       "      <td>-1.874467</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>CS(=O)(=O)Cl</td>\n",
       "      <td>-0.277514</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>CC(C)C=C</td>\n",
       "      <td>1.465089</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>CCc1cnccn1</td>\n",
       "      <td>-0.428367</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>CCCCCCCO</td>\n",
       "      <td>-0.105855</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   smiles         y\n",
       "0  CN(C)C(=O)c1ccc(cc1)OC -1.874467\n",
       "1            CS(=O)(=O)Cl -0.277514\n",
       "2                CC(C)C=C  1.465089\n",
       "3              CCc1cnccn1 -0.428367\n",
       "4                CCCCCCCO -0.105855"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first rows of the example data set to show the expected CSV layout.\n",
    "freesolv_preview = pd.read_csv('../data/freesolv/freesolv.csv').head()\n",
    "freesolv_preview"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer import make_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The atom feature size is read from the loaded data; it depends on the featurization used.\n",
    "d_atom = X[0][0].shape[1]\n",
    "\n",
    "# Hyperparameters passed to make_model as keyword arguments.\n",
    "model_params = dict(\n",
    "    d_atom=d_atom,\n",
    "    d_model=1024,\n",
    "    N=8,\n",
    "    h=16,\n",
    "    N_dense=1,\n",
    "    lambda_attention=0.33,\n",
    "    lambda_distance=0.33,\n",
    "    leaky_relu_slope=0.1,\n",
    "    dense_output_nonlinearity='relu',\n",
    "    distance_matrix_kernel='exp',\n",
    "    dropout=0.0,\n",
    "    aggregation_type='mean',\n",
    ")\n",
    "\n",
    "model = make_model(**model_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Pretrained Weights (optional)\n",
    "\n",
    "If you want to use the pre-trained weights to train your model, **you should not change model parameters in the cell above**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "pretrained_name = '../pretrained_weights.pt'  # This file should be downloaded first (See README.md).\n",
    "\n",
    "# map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on a machine without CUDA.\n",
    "# torch.load unpickles the file, so only load weights obtained from a trusted source.\n",
    "pretrained_state_dict = torch.load(pretrained_name, map_location='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy the pre-trained tensors into the model in place, matching entries by name.\n",
    "# Entries whose name contains 'generator' are skipped and keep their current values.\n",
    "model_state_dict = model.state_dict()\n",
    "for name, param in pretrained_state_dict.items():\n",
    "    if 'generator' not in name:\n",
    "        # Unwrap nn.Parameter objects so that the raw tensor is copied.\n",
    "        weights = param.data if isinstance(param, torch.nn.Parameter) else param\n",
    "        model_state_dict[name].copy_(weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run Training/Evaluation Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the model to the GPU only when one is available, so this cell also runs on CPU-only machines.\n",
    "# NOTE(review): the batches are assumed to be created on the matching device by construct_loader - confirm.\n",
    "if torch.cuda.is_available():\n",
    "    model.cuda()\n",
    "\n",
    "for batch in data_loader:\n",
    "    # The labels are unpacked as y_batch so that the data set labels `y` loaded earlier are not overwritten.\n",
    "    adjacency_matrix, node_features, distance_matrix, y_batch = batch\n",
    "    # True where a feature row has at least one non-zero entry (presumably real atoms rather than padding).\n",
    "    batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0\n",
    "    output = model(node_features, batch_mask, adjacency_matrix, distance_matrix, None)\n",
    "    ...  # Placeholder: add the loss/optimizer step or the evaluation code that uses output and y_batch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
